/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph_optimizer/heavy_format_propagation/heavy_format_supportformats_updater.h"

#include <utility>

namespace fe {
// Stores the two collaborators used by this updater:
//  - format_dtype_querier_ptr: reads a tensor's kernel support formats (see UpdateSupportFormats).
//  - format_dtype_setter_ptr: rewrites support formats/dtypes on a node and answers the op-pattern query
//    (see UpdateSupportFormats / IsSelectFormatOrBroadcast).
// The pointers are sink parameters taken by value, so they are moved into the members rather than copied a
// second time (for shared_ptr aliases this saves an extra atomic reference-count increment/decrement).
HeavyFormatSupportFormatsUpdater::HeavyFormatSupportFormatsUpdater(FormatDtypeQuerierPtr format_dtype_querier_ptr,
                                                                   FormatDtypeSetterPtr format_dtype_setter_ptr)
    : format_dtype_querier_ptr_(std::move(format_dtype_querier_ptr)),
      format_dtype_setter_ptr_(std::move(format_dtype_setter_ptr)) {}

// Nothing to release explicitly: the stored querier/setter members are destroyed by their own destructors.
HeavyFormatSupportFormatsUpdater::~HeavyFormatSupportFormatsUpdater() = default;

// Rewrites the support formats/dtypes of `node_ptr` so that the heavy format in `heavy_format_info`
// (combined with its sub_format) can be propagated through the node.
//
// @param node_ptr            node whose support formats may be rewritten.
// @param op_kernel_info_ptr  kernel info of the node's op.
// @param tensor_map          forwarded to GetAllInputAndOutputKernelInfo.
// @param heavy_format_info   heavy format being propagated, whether it arrives on an input or an output,
//                            and the anchor index it arrives on.
// @return SUCCESS when no update is required or the update succeeded; FAILED when kernel info cannot be
//         collected, the anchor index is invalid, the support formats cannot be queried, or the setter fails.
//         On a successful update the op is tagged with ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT.
Status HeavyFormatSupportFormatsUpdater::UpdateSupportFormats(const ge::NodePtr& node_ptr,
                                                              const OpKernelInfoPtr& op_kernel_info_ptr,
                                                              const std::vector<IndexNameMap>& tensor_map,
                                                              const HeavyFormatInfo& heavy_format_info) {
  auto op_desc_ptr = node_ptr->GetOpDesc();
  auto op_name = op_desc_ptr->GetName();
  auto op_type = op_desc_ptr->GetType();
  // 1. Only heavy formats listed in FE_GROUP_RELA_FORMAT_VECTOR (fz/fz_3d) arriving at ops whose pattern
  //    is Broadcast are handled here; in every other case the support formats are left untouched.
  if (!IsFzRelaFormat(heavy_format_info) || !IsSelectFormatOrBroadcast(op_desc_ptr, op_kernel_info_ptr)) {
    return SUCCESS;
  }

  // 2. Get the kernel support formats of the anchor the heavy format arrives on.
  vector<vector<InputOrOutputInfoPtr>> input_and_output_kernel;
  input_and_output_kernel.emplace_back();
  input_and_output_kernel.emplace_back();
  Status ret = GetAllInputAndOutputKernelInfo(op_kernel_info_ptr, node_ptr, tensor_map, input_and_output_kernel);
  if (ret != SUCCESS) {
    return FAILED;
  }
  if (input_and_output_kernel.size() != INPUT_OUTPUT_INDEX_BOTTOM) {
    FE_LOGW("Size of input kernel vector %zu is not correct for node %s.", input_and_output_kernel.size(),
            node_ptr->GetName().c_str());
    return FAILED;
  }

  // Bind by reference: only one element is read, there is no need to copy the whole vector.
  const std::vector<InputOrOutputInfoPtr>& input_or_output_info_vec =
      heavy_format_info.is_input ? input_and_output_kernel[INPUT_INDEX] : input_and_output_kernel[OUTPUT_INDEX];
  // Validate the anchor index instead of letting vector::at throw out of a Status-returning function.
  // A negative index wraps to a huge size_t and is rejected by the same comparison.
  const size_t anchor_index = static_cast<size_t>(heavy_format_info.anchor_index);
  if (anchor_index >= input_or_output_info_vec.size()) {
    FE_LOGW("Anchor index %zu is out of range (size %zu) for node %s.", anchor_index,
            input_or_output_info_vec.size(), node_ptr->GetName().c_str());
    return FAILED;
  }
  InputOrOutputInfoPtr input_or_output_info = input_or_output_info_vec[anchor_index];
  if (input_or_output_info == nullptr) {
    FE_LOGW("Kernel info of anchor %zu is null for node %s.", anchor_index, node_ptr->GetName().c_str());
    return FAILED;
  }
  vector<ge::Format> kernel_formats;
  if (format_dtype_querier_ptr_->GetSupportFormats(op_kernel_info_ptr, input_or_output_info, *(op_desc_ptr.get()),
                                                   kernel_formats) != SUCCESS) {
    return FAILED;
  }

  // 3. Update support formats and dtypes when the propagated format (heavy format + sub_format) requires it.
  auto propaga_heavy_format = static_cast<ge::Format>(
      ge::GetFormatFromSub(heavy_format_info.expected_heavy_format, heavy_format_info.sub_format));
  if (!NeedUpdateSupportFormats(op_desc_ptr, heavy_format_info, kernel_formats, propaga_heavy_format)) {
    return SUCCESS;
  }

  FE_LOGD("Op[name=%s,type=%s]: need to update support formats, propaga_heavy_format=%s for %s.", op_name.c_str(),
          op_type.c_str(), FormatToStr(propaga_heavy_format).c_str(), input_or_output_info->GetUniqueName().c_str());
  ret = format_dtype_setter_ptr_->SetSupportFormatDtypeByNode(node_ptr, heavy_format_info);
  if (ret != SUCCESS) {
    // Adjacent literals are concatenated at compile time; a backslash line-splice inside the literal would
    // embed the next line's indentation into the reported message.
    REPORT_FE_ERROR("[GraphOptJdgInst][SptFmtUpDtr][UptSptFmt] Op[name=%s,type=%s]: failed to set the support formats "
                    "and dtypes.", op_name.c_str(), op_type.c_str());
    return FAILED;
  }
  // Record the propagated format so NeedUpdateSupportFormats can skip an identical update later.
  (void)ge::AttrUtils::SetStr(op_desc_ptr, ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT,
                              ge::TypeUtils::FormatToSerialString(propaga_heavy_format));

  return SUCCESS;
}

// Tells whether the expected heavy format is one of the formats in FE_GROUP_RELA_FORMAT_VECTOR
// (the caller describes this set as the fz/fz_3d family).
bool HeavyFormatSupportFormatsUpdater::IsFzRelaFormat(const HeavyFormatInfo& heavy_format_info) {
  const auto& group_rela_formats = FE_GROUP_RELA_FORMAT_VECTOR;
  const auto hit = std::find(group_rela_formats.cbegin(), group_rela_formats.cend(),
                             heavy_format_info.expected_heavy_format);
  return hit != group_rela_formats.cend();
}

// Reports whether the op's pattern is Broadcast, by delegating to
// format_dtype_setter_ptr_->IsOpPatternBroadcast.
// NOTE(review): despite the name, no select-format/dynamic-format check is made here; confirm whether
// IsOpPatternBroadcast covers that case or whether the name is stale.
bool HeavyFormatSupportFormatsUpdater::IsSelectFormatOrBroadcast(const ge::OpDescPtr& op_desc_ptr,
                                                                 const OpKernelInfoPtr& op_kernel_info_ptr) {
  auto& op_desc = *op_desc_ptr;
  return format_dtype_setter_ptr_->IsOpPatternBroadcast(op_kernel_info_ptr, op_desc);
}

// Decides whether the op's support formats have to be rewritten for `propaga_heavy_format`.
//
// - If the op already carries ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT (written after a previous successful
//   update), an update is needed only when the recorded format string differs from the serialized
//   `propaga_heavy_format`.
// - Otherwise an update is needed only when the kernel supports `expected_heavy_format` AND
//   `sub_format` is greater than 1.
bool HeavyFormatSupportFormatsUpdater::NeedUpdateSupportFormats(const ge::OpDescPtr& op_desc_ptr,
                                                                const HeavyFormatInfo& heavy_format_info,
                                                                const vector<ge::Format>& kernel_formats,
                                                                ge::Format propaga_heavy_format) {
  const bool already_propagated = ge::AttrUtils::HasAttr(op_desc_ptr, ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT);
  if (already_propagated) {
    string recorded_format;
    (void)ge::AttrUtils::GetStr(op_desc_ptr, ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT, recorded_format);
    const auto target_format = ge::TypeUtils::FormatToSerialString(propaga_heavy_format);
    if (recorded_format != target_format) {
      return true;
    }
    FE_LOGD("Op[name=%s,type=%s]: the attr %s %s is equal to propaga_format %s.", op_desc_ptr->GetName().c_str(),
            op_desc_ptr->GetType().c_str(), ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT.c_str(), recorded_format.c_str(),
            target_format.c_str());
    return false;
  }

  const bool kernel_supports_heavy_format =
      std::find(kernel_formats.begin(), kernel_formats.end(), heavy_format_info.expected_heavy_format) !=
      kernel_formats.end();
  const bool sub_format_above_one = heavy_format_info.sub_format > 1;
  if (!kernel_supports_heavy_format || !sub_format_above_one) {
    return false;
  }
  FE_LOGD("Op[name=%s,type=%s]: no attr %s.", op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str(),
          ATTR_NAME_FE_PROPAGAT_HEAVY_FORMAT.c_str());
  return true;
}
}  // namespace fe